fig, ax = plt.subplots()
ax.plot(x_range, y_range, linewidth=4)
ax.set_title('f(x) = max(x², 1) - e^(-900*(x-0.6)²) - e^(-900*(x+0.6)²) - e^(-25*x²)')

# 设置边距
plt.subplots_adjust(left=0.03, right=1, top=1, bottom=0.05)  # 根据需要调整这些参数
# 初始化点
points = {name: ax.plot([], [], marker='o', color=colors[name], markersize=20, label=name, zorder=10)[0] for name in trajectories.keys()}
# 标记起始点
start_y = func(torch.tensor(start_point)).item()  # 计算对应的 y 值
ax.scatter(start_point, start_y, color='k', s=400, label='Start Point', zorder=16)  # 标记起始点
# ax.text(start_point, start_y, 'Start Point', color='k', fontsize=12, ha='left', zorder=16)  # 添加标签
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.axhline(y=0, color='gray', linestyle='--')  # y=0 线
ax.legend(loc='upper left', fontsize=24, bbox_to_anchor=(0, 1), frameon=False, handletextpad=1.5,  # 增加每个图例标签与图例边界的间距
labelspacing=1.5)   # 增加标签之间的间距

# 更新函数
def update(frame):
    for name, trajectory in trajectories.items():
        if frame < len(trajectory):
            x_val, y_val= trajectory[frame]
            points[name].set_data([x_val], [y_val])  # 确保传入的是列表
    # 返回所有点的 Artist 对象
    return list(points.values())

# 创建动画
ani = FuncAnimation(fig, update, frames=len(trajectory), blit=True)

# 保存GIF
ani.save(f'{args.id}_1d_lr_{lr}.gif', 
        writer=PillowWriter(fps=image_per_second))
plt.show()